# R script for replicating simulation analysis in appendix section "Simulation Analysis"
# Zhai & Garside 2022
# original device: MacPro 13, R 4.1.2
# recommended working directory: Desktop

# Use the Desktop as working directory: the relative paths read further down
# ("R_00_helpers.R", "data_spatmat.RDS") are resolved against it.
setwd("~/Desktop") # default wkd
# Remove all (non-hidden) objects from the global environment before running.
# NOTE(review): when run interactively this also deletes any objects already
# present in the session.
rm(list = ls())

# pkgs --------------------------------------------------------------------

if (!require("pacman")) install.packages("pacman")
pacman::p_load(tidyverse, DeclareDesign, fixest, knitr, scales, broom)

# helpers -----------------------------------------------------------------

source("R_00_helpers.R")

# data (spatmat) ----------------------------------------------------------

# Read the stored object and keep its "knn" element
# (presumably a k-nearest-neighbour spatial weights matrix -- confirm).
swms <- readRDS("data_spatmat.RDS")
swms.knn <- swms[["knn"]]

# Give the columns the same identifiers as the rows.
colnames(swms.knn) <- rownames(swms.knn)
nID <- as.numeric(colnames(swms.knn))

# Seeded draw of 5172 identifiers (without replacement), converted back to
# character so they can index the matrix by name; W keeps only those
# rows and columns, in sampled order.
set.seed(1)
nW <- as.character(sample(nID, size = 5172))
W <- swms.knn[nW, nW]

# designer ----------------------------------------------------------------

# Build the simulated difference-in-differences design: a crossed
# unit x period population, an ATT inquiry, the observed outcome, and four
# fixest::feols estimators with unit and period fixed effects.
#
# Arguments:
#   N_units    number of units (size of the `units` level).
#   N_periods  number of periods; also used as every unit's treatment time
#              (D_time), so treated units have D = 1 from period N_periods on.
#   tau        coefficient on the treatment indicator D in the outcome formula.
#   beta       coefficient on the control X in the outcome formula.
#   corr       passed as `rho` to sim_corrv() when X is generated from D
#              (presumably the X-D correlation; sim_corrv() is not defined in
#              this file -- confirm in the helpers).
#   p_d1       passed as `p` to sim_dplus() when D1 is generated from D.
#   tau1       coefficient on D1 in the outcome formula.
#   w          passed to sim_slag() together with D and periods; defaults to
#              the global matrix W.
#   rho        coefficient on R in the outcome formula.
#
# Returns the object obtained by chaining the population, inquiry,
# measurement and estimator declarations with `+`.
designer <- function(N_units = 5172, N_periods = 2, tau = 1, beta = 10, corr = 0, p_d1 = 0.1,  tau1 = 0.1, w = W, rho = 0.5) {
    # define population
    declare_population(
      units = add_level(
        N = N_units, # number of units
        U_unit = rnorm(N), # unit fixed effect (standard normal draw)
        D_unit = rbinom(N, 1, 0.3), # treatment-group dummy (p = 0.3)
        D_time = rep(N_periods, N) # treatment time, set to N_periods for every unit
      ),
      periods = add_level(
        N = N_periods, # number of periods
        U_time = rnorm(N), # period fixed effect (standard normal draw)
        nest = FALSE # not nested in units, so it can be crossed with them below
      ),
      unit_period = cross_levels(
        by = join(units, periods), # one row per unit x period combination
        U = rnorm(N), # unit-period shock
        D = if_else(D_unit == 1 & as.numeric(periods) >= D_time, 1, 0), # treated unit observed at or after its treatment time
        X = sim_corrv(D, rho = corr), # control variable generated from D by helper sim_corrv()
        D1 = sim_dplus(D, p = p_d1), # additional treatment variable generated from D by helper sim_dplus()
        R = sim_slag(D, w, periods), # indirect treatment generated from D, w and periods by helper sim_slag()
        # NOTE(review): only D is listed in `conditions`; X, D1 and R are built
        # from the realised D above, so they are presumably held fixed across
        # the two potential outcomes -- confirm.
        potential_outcomes(
          Y ~ U + U_unit + U_time + D*tau + X*beta + D1*tau1 + R*rho, # outcome equation
          conditions = list(D = c(0, 1)) # untreated / treated condition
        )
      )
    ) +
      # define estimand
      # NOTE(review): no subset is given, so this appears to average over all
      # unit-period rows rather than treated rows only -- confirm intended.
      declare_inquiries(
        ATT = mean(Y_D_1 - Y_D_0) # mean difference between the two potential outcomes
      ) +
      # define measurement
      declare_measurement(Y = reveal_outcomes(Y~D)) + # observed outcome revealed according to D
      # define estimators
      declare_estimator(
        Y~D, # basic model
        cluster = "units", # standard errors clustered by unit
        fixef = c("units", "periods"), # two-way fixed effects (unit and period)
        model = fixest::feols,
        model_summary = tidy_model, # tidy_model is not defined in this file (see helpers)
        inquiry = "ATT",
        label = "DID"
      ) +
      declare_estimator(
        Y~D+X, # basic model + control
        cluster = "units", # standard errors clustered by unit
        fixef = c("units", "periods"), # two-way fixed effects (unit and period)
        model = fixest::feols,
        model_summary = tidy_model,
        inquiry = "ATT",
        term = "D", # report the coefficient on D only
        label = "DIDX"
      ) +
      declare_estimator(
        Y~D+X+D1, # control + additional treatment D1
        cluster = "units", # standard errors clustered by unit
        fixef = c("units", "periods"), # two-way fixed effects (unit and period)
        model = fixest::feols,
        model_summary = tidy_model,
        inquiry = "ATT",
        term = "D", # report the coefficient on D only
        label = "DID+"
      ) +
      declare_estimator(
        Y~D+X+R, # control + indirect treatment R
        cluster = "units", # standard errors clustered by unit
        fixef = c("units", "periods"), # two-way fixed effects (unit and period)
        model = fixest::feols,
        model_summary = tidy_model,
        inquiry = "ATT",
        term = "D", # report the coefficient on D only
        label = "SDID"
      )
  }

# Print the function definition to the console/log.
print(designer)

# simulation --------------------------------------------------------------

# One draw of the estimand from a freshly built design.
# With a single estimand value, sd() has nothing to average over.
att <- draw_estimand(designer())
att.mean <- mean(att[["estimand"]])
att.se <- sd(att[["estimand"]])

# One (separate) draw of the estimates. Rows are read positionally (1-4) and
# are assumed to follow the declaration order DID, DIDX, DID+, SDID.
att_hat <- draw_estimates(designer())

# basic model (DID)
att_hat.mean.b <- att_hat[["estimate"]][1]
att_hat.se.b <- att_hat[["std.error"]][1]
# model with control (DIDX)
att_hat.mean.f <- att_hat[["estimate"]][2]
att_hat.se.f <- att_hat[["std.error"]][2]
# model with additional treatment (DID+)
att_hat.mean.p <- att_hat[["estimate"]][3]
att_hat.se.p <- att_hat[["std.error"]][3]
# model with indirect treatment (SDID)
att_hat.mean.s <- att_hat[["estimate"]][4]
att_hat.se.s <- att_hat[["std.error"]][4]

# table (1st) -------------------------------------------------------------

# Collect the estimand and the four estimates (value and spread) in one
# data frame, then render it as a table rounded to three digits.
simest <- data.frame(
  Term = c(
    "Estimand",
    "Estimate (DID)",
    "Estimate (DIDX)",
    "Estimate (DID+)",
    "Estimate (SDID)"
  ),
  Mean = c(
    att.mean,
    att_hat.mean.b,
    att_hat.mean.f,
    att_hat.mean.p,
    att_hat.mean.s
  ),
  SE = c(
    att.se,
    att_hat.se.b,
    att_hat.se.f,
    att_hat.se.p,
    att_hat.se.s
  )
)
knitr::kable(
  simest,
  caption = "Example of true and estimated ATTs.",
  digits = 3
)

# graph -------------------------------------------------------------------

# Draw one simulated data set, tag each row with its group x period cell,
# derive a per-unit group flag (1 if the unit is ever treated) and shift the
# outcome so that its minimum is zero.
simdata <- draw_data(design = designer()) %>%
  mutate(DxT = case_when(
    D_unit == 1 & periods == 1 ~ "Y10",
    D_unit == 1 & periods == 2 ~ "Y11",
    D_unit == 0 & periods == 1 ~ "Y00",
    D_unit == 0 & periods == 2 ~ "Y01"
  )) %>%
  group_by(units) %>%
  mutate(groupID = max(D)) %>%
  ungroup() %>%
  mutate(Y_rsc = Y - min(Y))

# Cell means and standard deviations of the rescaled outcome.
# `.groups = "drop"` returns an ungrouped table and silences the regrouping
# message that summarise() prints after a multi-variable group_by().
simdata.sum <- simdata %>%
  group_by(DxT, groupID, periods) %>%
  summarise(
    across(.cols = "Y_rsc", .fns = list(mean = mean, sd = sd)),
    .groups = "drop"
  )

# Raw points (jittered and dodged by group) overlaid with the cell means,
# +/- one standard deviation bars and a line joining each group's means.
# The outcome is divided by 100 because the y axis uses percent labels.
ggplot(
  simdata,
  aes(x = periods, color = factor(groupID), group = factor(groupID))
) +
  # `width` is not a geom_point() parameter (it was ignored with a warning);
  # the horizontal spread is controlled by position_jitterdodge() below.
  geom_point(
    aes(y = Y_rsc / 1e+2),
    alpha = 0.1,
    shape = 1,
    position = position_jitterdodge(dodge.width = 0.5, jitter.width = 0.7)
  ) +
  geom_point(
    data = simdata.sum,
    aes(x = periods, y = Y_rsc_mean / 1e+2),
    alpha = 1,
    size = 6,
    position = position_dodge(width = 0.5)
  ) +
  geom_linerange(
    data = simdata.sum,
    aes(
      x = periods,
      y = Y_rsc_mean / 1e+2,
      ymin = (Y_rsc_mean - Y_rsc_sd) / 1e+2,
      ymax = (Y_rsc_mean + Y_rsc_sd) / 1e+2
    ),
    position = position_dodge(width = 0.5)
  ) +
  geom_line(
    data = simdata.sum,
    aes(x = periods, y = Y_rsc_mean / 1e+2),
    position = position_dodge(width = 0.5)
  ) +
  # `labels` spelled out in full instead of relying on partial matching.
  scale_color_viridis_d(
    option = "H",
    begin = 0.1,
    end = 0.9,
    labels = c("Untreated", "Treated")
  ) +
  scale_x_discrete(labels = c("Pre", "Post")) +
  scale_y_continuous(
    labels = scales::percent,
    breaks = scales::pretty_breaks(n = 6)
  ) +
  labs(
    title = "Simulated DGP in graphical form",
    caption = "Note: error bars = one standard deviation from mean.",
    x = "Period",
    y = "Vote Share (%)",
    color = "T/C Group"
  ) +
  theme_classic()

# Save the last plot as an 8 x 6 inch PNG on the Desktop.
ggsave("~/Desktop/appendix_simdgp.png", device = "png", width = 8, height = 6)

# diagnosis ---------------------------------------------------------------

# Summary statistics computed across simulation runs: average estimand and
# estimate, bias, root mean squared error, share of runs with p < 0.05, and
# share of runs whose confidence interval contains the estimand.
diagnosands <- declare_diagnosands(
  mean_true = mean(estimand),
  mean_est = mean(estimate),
  bias = mean(estimate - estimand),
  rmse = sqrt(mean((estimate - estimand)^2)),
  power = mean(p.value < 0.05),
  coverage = mean(conf.low <= estimand & estimand <= conf.high)
)

# Seeded run of 500 simulations of the default design, without bootstrapping
# the diagnosands.
set.seed(1)
diagnosis <- diagnose_design(
  designer(),
  diagnosands = diagnosands,
  sims = 500,
  bootstrap_sims = FALSE
)

# table (2nd) -------------------------------------------------------------

# Reshape the diagnosis into a display table without the identifying columns,
# relabel the two mean columns, and render it rounded to three digits.
knitr::kable(
  rename(
    reshape_diagnosis(
      diagnosis,
      exclude = c("Design", "Design Label", "Inquiry Label", "Term", "N Sims")
    ),
    "E[ATT]" = "Mean True",
    "E[ATT Est.]" = "Mean Est"
  ),
  caption = "Example of estimated diagnostic statistics.",
  digits = 3
)

# cleanup -----------------------------------------------------------------

# Unload the add-on packages that were attached for this script.
pacman::p_unload("all")
# Remove all (non-hidden) objects created above from the global environment.
rm(list = ls())
